import os
import pickle

import numpy as np
from sklearn import preprocessing
from sklearn.utils import resample
from sklearn.linear_model import LogisticRegression
from sklift.datasets import fetch_criteo
from sklift.models import TwoModels


def save_pickle(obj, path):
    """Pickle ``obj`` to ``path``, closing the file handle deterministically.

    Args:
        obj: Any picklable object (in this script: fitted scikit-learn estimators).
        path: Destination file path; an existing file is overwritten.

    Note:
        ``pickle.load`` can execute arbitrary code, so only load the resulting
        file from a trusted location.
    """
    with open(path, 'wb') as fh:
        pickle.dump(obj, fh)


# Number of non-visit rows kept per visit row when subsampling (3:1).
NEG_PER_POS = 3


if __name__ == '__main__':

    dataset = fetch_criteo(target_col='visit',
                           treatment_col='treatment',
                           data_home='datasets')

    # Scale every feature column to [0, 1]. The scaler is fitted on the full
    # dataset, before subsampling.
    min_max_scaler = preprocessing.MinMaxScaler()
    X = min_max_scaler.fit_transform(dataset.data)
    treat = dataset.treatment.values.astype(int)
    y = dataset.target.values.astype(int)

    # Build each boolean mask once. All visit rows are kept; non-visit rows are
    # downsampled (without replacement) to NEG_PER_POS times as many.
    visit_mask = y == 1
    notvisit_mask = y == 0
    X_visit = X[visit_mask]
    treat_visit = treat[visit_mask]
    n_visit = len(X_visit)
    n_notvisit_kept = n_visit * NEG_PER_POS
    # resample raises ValueError if fewer than n_notvisit_kept non-visit rows exist.
    X_notvisit_subsampled, treat_notvisit_subsampled = resample(
        X[notvisit_mask], treat[notvisit_mask],
        n_samples=n_notvisit_kept, replace=False, random_state=42)

    # Row order is non-visit rows first, then visit rows; the label vector is
    # built in the same order. Labels are float (np.zeros/np.ones default dtype).
    X_subsampled = np.concatenate([X_notvisit_subsampled, X_visit])
    treat_subsampled = np.concatenate([treat_notvisit_subsampled, treat_visit])
    visit_subsampled = np.concatenate([np.zeros(n_notvisit_kept), np.ones(n_visit)])

    save_dir_data = '../save/criteo/data/'
    os.makedirs(save_dir_data, exist_ok=True)

    np.save(os.path.join(save_dir_data, 'X_subsampled.npy'), X_subsampled)
    np.save(os.path.join(save_dir_data, 'treat_subsampled.npy'), treat_subsampled)
    np.save(os.path.join(save_dir_data, 'visit_subsampled.npy'), visit_subsampled)

    model_treat_lg = LogisticRegression(max_iter=1000)
    model_control_lg = LogisticRegression(max_iter=1000)
    tm_lg = TwoModels(estimator_trmnt=model_treat_lg,
                      estimator_ctrl=model_control_lg,
                      method='vanilla')
    # NOTE(review): the pickling and predict_proba calls below use
    # model_treat_lg / model_control_lg directly. This relies on TwoModels.fit
    # fitting these same estimator objects in place rather than clones; confirm
    # this when upgrading scikit-uplift.
    tm_lg = tm_lg.fit(X_subsampled, visit_subsampled, treat_subsampled)

    save_dir_model = '../save/criteo/model/'
    os.makedirs(save_dir_model, exist_ok=True)

    filename_treat = os.path.join(save_dir_model, 'model_treat_subsample_lg.sav')
    filename_control = os.path.join(save_dir_model, 'model_control_subsample_lg.sav')

    save_pickle(model_treat_lg, filename_treat)
    save_pickle(model_control_lg, filename_control)

    # Predicted probability of the positive class (label 1.0) on the training
    # subsample. In the saved array, column 0 is the control model and
    # column 1 is the treatment model.
    visit_proba_treat = model_treat_lg.predict_proba(X_subsampled)[:, 1]
    visit_proba_control = model_control_lg.predict_proba(X_subsampled)[:, 1]
    visit_proba = np.vstack([visit_proba_control, visit_proba_treat]).T
    np.save(os.path.join(save_dir_data, 'visit_proba_subsampled.npy'), visit_proba)
